{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensorflow - TensorRT Inference Example from Checkpoint\n",
    "In this notebook, we demonstrate the process to create a TF-TRT optimized model  for inference from a Tensorflow *checkpoint*. We will also validate the resulting models, both on accuracy and speed.\n",
    "\n",
    "## Notebook  Content\n",
    "1. [Pre-requisite: data and model](#1)\n",
    "1. [Verifying the orignal FP32 model](#2)\n",
    "1. [Creating TF-TRT FP32 model](#3)\n",
    "1. [Creating TF-TRT FP16 model](#4)\n",
    "1. [Creating TF-TRT INT8 model](#5)\n",
    "1. [Calibrating TF-TRT INT8 model with raw JPEG images](#6)\n",
    "\n",
    "This notebook has been successfully tested in the NVIDIA NGC Tensorflow container `nvcr.io/nvidia/tensorflow:19.04-py3` that can be downloaded from http://ngc.nvidia.com.\n",
    "\n",
    "## Quickstart\n",
    "\n",
    "We wil be using the ImageNet dataset in TFrecords format. Google provides an excellent all-in-one script for downloading and preparing the ImageNet dataset at \n",
    "\n",
    "https://github.com/tensorflow/models/blob/master/research/inception/inception/data/download_and_preprocess_imagenet.sh.\n",
    "\n",
    "We will run this demonstration with a saved model from the Tensorflow Slim model zoo \n",
    "\n",
    "https://github.com/tensorflow/models/tree/master/research/slim\n",
    "\n",
    "To run this notebook, start the NGC TF container providing correct path to ImageNet validation data and a TF Slim saved checkpoint:\n",
    "\n",
    "```bash\n",
    "nvidia-docker run -it -p 8888:8888 -v /path/to/image_net/:/data  -v /path/to/saved_model:/saved_model --name TFTRT nvcr.io/nvidia/tensorflow:19.04-py3\n",
    "```\n",
    "Then start Jupyter notebook within the container with:\n",
    "\n",
    "```bash\n",
    "jupyter notebook --ip 0.0.0.0 --port 8888  --allow-root\n",
    "```\n",
    "\n",
    "Connect to Jupyter notebook web interface from your local host http://localhost:8888. \n",
    "\n",
    "<a id=\"1\"></a>\n",
    "## 1. Pre-requisite: data and model\n",
    "\n",
    "We first install some extra packages and external dependencies. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "pushd /workspace/nvidia-examples/tensorrt/tftrt/examples/object_detection\n",
    "bash install_dependencies.sh;\n",
    "popd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ii  libnvinfer-dev                         5.1.2-1+cuda10.1                      amd64        TensorRT development libraries and headers\r\n",
      "ii  libnvinfer5                            5.1.2-1+cuda10.1                      amd64        TensorRT runtime libraries\r\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow.contrib.tensorrt as trt\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import time\n",
    "import logging\n",
    "\n",
    "logging.getLogger(\"tensorflow\").setLevel(logging.ERROR)\n",
    "\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth=True\n",
    "\n",
    "#check TensorRT version\n",
    "!dpkg -l | grep nvinfer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data\n",
    "We first check that the correct Imagenet validation data folder has been mounted. In this experiment, we shall employ the Imagenet validation data to verify the accuracy and inference speed of the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_files(data_dir, filename_pattern):\n",
    "    if data_dir == None:\n",
    "        return []\n",
    "    files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern))\n",
    "    if files == []:\n",
    "        raise ValueError('Can not find any files in {} with '\n",
    "                         'pattern \"{}\"'.format(data_dir, filename_pattern))\n",
    "    return files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 128 calibration files. \n",
      "/data/validation-00114-of-00128\n",
      "/data/validation-00094-of-00128\n",
      "...\n"
     ]
    }
   ],
   "source": [
    "VALIDATION_DATA_DIR = \"/data\"\n",
    "calibration_files = get_files(VALIDATION_DATA_DIR, 'validation*')\n",
    "print('There are %d calibration files. \\n%s\\n%s\\n...'%(len(calibration_files), calibration_files[0], calibration_files[-1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF model checkpoint\n",
    "If not already downloaded, we will be downloading and working with a ResNet-50 v1 checkpoint from https://github.com/tensorflow/models/tree/master/research/slim. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "FILE=/saved_model/resnet_v1_50_2016_08_28.tar.gz\n",
    "if [ -f $FILE ]; then\n",
    "   echo \"The file '$FILE' exists.\"\n",
    "else\n",
    "   echo \"The file '$FILE' in not found. Downloading...\"\n",
    "   wget -P /saved_model/ http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz\n",
    "fi\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "resnet_v1_50.ckpt\r\n"
     ]
    }
   ],
   "source": [
    "!tar -xzvf /saved_model/resnet_v1_50_2016_08_28.tar.gz -C /saved_model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions\n",
    "We define a few helper functions to read and preprocess Imagenet data from TFRecord files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define some global variables\n",
    "BATCH_SIZE = 8\n",
    "SAVED_MODEL_DIR = \"/saved_model/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deserialize_image_record(record):\n",
    "    feature_map = {\n",
    "        'image/encoded':          tf.FixedLenFeature([ ], tf.string, ''),\n",
    "        'image/class/label':      tf.FixedLenFeature([1], tf.int64,  -1),\n",
    "        'image/class/text':       tf.FixedLenFeature([ ], tf.string, ''),\n",
    "        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),\n",
    "        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),\n",
    "        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),\n",
    "        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32)\n",
    "    }\n",
    "    with tf.name_scope('deserialize_image_record'):\n",
    "        obj = tf.parse_single_example(record, feature_map)\n",
    "        imgdata = obj['image/encoded']\n",
    "        label   = tf.cast(obj['image/class/label'], tf.int32)\n",
    "        bbox    = tf.stack([obj['image/object/bbox/%s'%x].values\n",
    "                            for x in ['ymin', 'xmin', 'ymax', 'xmax']])\n",
    "        bbox = tf.transpose(tf.expand_dims(bbox, 0), [0,2,1])\n",
    "        text    = obj['image/class/text']\n",
    "        return imgdata, label, bbox, text\n",
    "\n",
    "from preprocessing import vgg_preprocessing\n",
    "def preprocess(record):\n",
    "    # Parse TFRecord\n",
    "    imgdata, label, bbox, text = deserialize_image_record(record)\n",
    "    label -= 1 # Change to 0-based (don't use background class)\n",
    "    try:    image = tf.image.decode_jpeg(imgdata, channels=3, fancy_upscaling=False, dct_method='INTEGER_FAST')\n",
    "    except: image = tf.image.decode_png(imgdata, channels=3)\n",
    "\n",
    "    image = vgg_preprocessing.preprocess_image(image, 224, 224, is_training=False)\n",
    "    return image, label\n",
    "\n",
    "dataset = tf.data.TFRecordDataset(calibration_files)    \n",
    "dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess, batch_size=BATCH_SIZE, num_parallel_calls=20))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next are two functions to benchmark models speed and accuracy, either in a `graph_def` form or a `saved model` form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_frozen_graph(frozen_graph, SAVED_MODEL_DIR=None, dataset=dataset, BATCH_SIZE=8):\n",
    "    with tf.Session(graph=tf.Graph(), config=config) as sess:\n",
    "        # prepare dataset iterator\n",
    "        iterator = dataset.make_one_shot_iterator()\n",
    "        next_element = iterator.get_next()\n",
    "\n",
    "        output_node = tf.import_graph_def(\n",
    "            frozen_graph,\n",
    "            return_elements=['classes'],\n",
    "            name=\"\")\n",
    "        \n",
    "        print('Warming up for 50 batches...')\n",
    "        for _ in range (50):\n",
    "            sess.run(['classes:0'], feed_dict={\"input:0\": sess.run(next_element)[0]})\n",
    "\n",
    "        num_hits = 0\n",
    "        num_predict = 0\n",
    "        start_time = time.time()\n",
    "        try:\n",
    "            while True:        \n",
    "                image_data = sess.run(next_element)    \n",
    "                img = image_data[0]\n",
    "                label = image_data[1].squeeze()\n",
    "                output = sess.run(['classes:0'], feed_dict={\"input:0\": img})\n",
    "                prediction = output[0]\n",
    "                num_hits += np.sum(prediction == label)\n",
    "                num_predict += len(prediction)\n",
    "        except tf.errors.OutOfRangeError as e:\n",
    "            pass\n",
    "\n",
    "        print('Accuracy: %.2f%%'%(100*num_hits/num_predict)) \n",
    "        print('Inference speed: %.2f samples/s'%(num_predict/(time.time()-start_time)))\n",
    "        \n",
    "        #Optionally, save model for serving if an ouput directory argument is presented\n",
    "        if SAVED_MODEL_DIR:\n",
    "            print('Saving model to %s'%SAVED_MODEL_DIR)\n",
    "            tf.saved_model.simple_save(\n",
    "                session=sess,\n",
    "                export_dir=SAVED_MODEL_DIR,\n",
    "                inputs={\"input\":tf.get_default_graph().get_tensor_by_name(\"input:0\")},\n",
    "                outputs={\"classes\":tf.get_default_graph().get_tensor_by_name(\"classes:0\")},\n",
    "                legacy_init_op=None\n",
    "             )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_saved_model(SAVED_MODEL_DIR, dataset=dataset, BATCH_SIZE=8):\n",
    "    with tf.Session(graph=tf.Graph(), config=config) as sess:\n",
    "        # prepare dataset iterator\n",
    "        iterator = dataset.make_one_shot_iterator()\n",
    "        next_element = iterator.get_next()\n",
    "\n",
    "        tf.saved_model.loader.load(\n",
    "            sess, [tf.saved_model.tag_constants.SERVING], SAVED_MODEL_DIR)\n",
    "\n",
    "        print('Warming up for 50 batches...')\n",
    "        for _ in range (50):\n",
    "            sess.run(['classes:0'], feed_dict={\"input:0\": sess.run(next_element)[0]})\n",
    "\n",
    "        print('Benchmarking inference engine...')\n",
    "        num_hits = 0\n",
    "        num_predict = 0\n",
    "        start_time = time.time()\n",
    "        try:\n",
    "            while True:        \n",
    "                image_data = sess.run(next_element)    \n",
    "                img = image_data[0]\n",
    "                label = image_data[1].squeeze()\n",
    "                output = sess.run(['classes:0'], feed_dict={\"input:0\": img})            \n",
    "                prediction = output[0]\n",
    "                num_hits += np.sum(prediction == label)\n",
    "                num_predict += len(prediction)\n",
    "        except tf.errors.OutOfRangeError as e:\n",
    "            pass\n",
    "\n",
    "        print('Accuracy: %.2f%%'%(100*num_hits/num_predict))\n",
    "        print('Inference speed: %.2f samples/s'%(num_predict/(time.time()-start_time)))"
   ]
  },
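  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two benchmark functions above share the same bookkeeping: they count correct top-1 predictions batch by batch and divide the number of processed samples by the elapsed wall-clock time. The following is a minimal NumPy-only sketch of that bookkeeping, not the notebook's own code. The names `summarize_benchmark` and `fake_predict`, and the random data, are hypothetical stand-ins for the TensorFlow session and the ImageNet iterator.\n",
    "\n",
    "```python\n",
    "import time\n",
    "import numpy as np\n",
    "\n",
    "def summarize_benchmark(batches, predict_fn):\n",
    "    # Return (top-1 accuracy in percent, samples per second)\n",
    "    num_hits = 0\n",
    "    num_predict = 0\n",
    "    start_time = time.time()\n",
    "    for images, labels in batches:\n",
    "        prediction = predict_fn(images)\n",
    "        num_hits += int(np.sum(prediction == labels.squeeze()))\n",
    "        num_predict += len(prediction)\n",
    "    elapsed = max(time.time() - start_time, 1e-9)\n",
    "    return 100.0 * num_hits / num_predict, num_predict / elapsed\n",
    "\n",
    "# Hypothetical data: 4 batches of 8 samples with labels of shape (8, 1)\n",
    "rng = np.random.RandomState(0)\n",
    "batches = [(rng.rand(8, 224, 224, 3), rng.randint(0, 1000, size=(8, 1)))\n",
    "           for _ in range(4)]\n",
    "\n",
    "def fake_predict(images):\n",
    "    # Stand-in model that always predicts class 0\n",
    "    return np.zeros(len(images), dtype=np.int64)\n",
    "\n",
    "accuracy, speed = summarize_benchmark(batches, fake_predict)\n",
    "print('Accuracy: %.2f%%' % accuracy)\n",
    "print('Inference speed: %.2f samples/s' % speed)\n",
    "```\n",
    "\n",
    "Note that, as in the benchmark functions above, the measured time in a real run also includes fetching each batch from the input pipeline, so the reported throughput is an end-to-end figure rather than pure model latency."
   ]
  },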
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "## 2. Verifying the orignal FP32 model\n",
    "\n",
    "We first load and benchmark the Resnet-v1-50 model from TF slim. Note that the checkpoint downloaded from http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz doesn't come with any meta data, therefore we will need to employ TF Slim Net factory to get the model definition. The newer checkpoints generated by Tensorflow generally comes with enough meta information to load the network from the achirve. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nets.nets_factory\n",
    "\n",
    "graph = tf.Graph()\n",
    "with graph.as_default():\n",
    "    with tf.Session(config=config) as sess:\n",
    "        tf_input = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')\n",
    "        network_fn = nets.nets_factory.get_network_fn('resnet_v1_50', 1000,\n",
    "                                                      is_training=False)\n",
    "        tf_net, tf_end_points = network_fn(tf_input)\n",
    "                \n",
    "        saver = tf.train.Saver()\n",
    "        saver.restore(sess, SAVED_MODEL_DIR+\"resnet_v1_50.ckpt\")\n",
    "        \n",
    "        tf_output = tf.identity(tf_net, name='logits')\n",
    "        tf_output_classes = tf.argmax(tf_output, axis=1, name='classes')        \n",
    "        #tf_output_classes = tf.reshape(tf_output_classes, (BATCH_SIZE,), name='classes')\n",
    "        \n",
    "        # freeze graph\n",
    "        fp32_frozen_graph = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            sess.graph_def,\n",
    "            output_node_names=['logits', 'classes']\n",
    "        ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FP32_SAVED_MODEL_DIR = SAVED_MODEL_DIR+\"model/Resnet_FP32/1\"\n",
    "!rm -rf $FP32_SAVED_MODEL_DIR\n",
    "\n",
    "benchmark_frozen_graph(fp32_frozen_graph, FP32_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
      "\n",
      "signature_def['serving_default']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "    inputs['input'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 224, 224, 3)\n",
      "        name: input:0\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['classes'] tensor_info:\n",
      "        dtype: DT_INT64\n",
      "        shape: (-1)\n",
      "        name: classes:0\n",
      "  Method name is: tensorflow/serving/predict\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --all --dir $FP32_SAVED_MODEL_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_saved_model(FP32_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"3\"></a>\n",
    "## 3. Creating TF-TRT FP32 model\n",
    "\n",
    "Next, we convert the naitive TF FP32 model to TF-TRT FP32, then verify model accuracy and inference speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "trt_fp32_graph = trt.create_inference_graph(\n",
    "    input_graph_def=fp32_frozen_graph,\n",
    "    outputs=['classes'],\n",
    "    max_batch_size=BATCH_SIZE,\n",
    "    precision_mode=\"FP32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRT_FP32_SAVED_MODEL_DIR = SAVED_MODEL_DIR+\"model/Resnet_TRT_FP32/1\"\n",
    "!rm -rf $TRT_FP32_SAVED_MODEL_DIR\n",
    "\n",
    "benchmark_frozen_graph(trt_fp32_graph, TRT_FP32_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
      "\n",
      "signature_def['serving_default']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "    inputs['input'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 224, 224, 3)\n",
      "        name: input:0\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['classes'] tensor_info:\n",
      "        dtype: DT_INT64\n",
      "        shape: unknown_rank\n",
      "        name: classes:0\n",
      "  Method name is: tensorflow/serving/predict\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --all --dir $TRT_FP32_SAVED_MODEL_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_saved_model(TRT_FP32_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"4\"></a>\n",
    "## 4. Creating TF-TRT FP16 model\n",
    "\n",
    "\n",
    "Next, we convert the naitive TF FP32 model to TF-TRT FP16, then verify model accuracy and inference speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "trt_fp16_graph = trt.create_inference_graph(\n",
    "    input_graph_def=fp32_frozen_graph,\n",
    "    outputs=['classes'],\n",
    "    max_batch_size=BATCH_SIZE,\n",
    "    precision_mode=\"FP16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRT_FP16_SAVED_MODEL_DIR = SAVED_MODEL_DIR+\"/model/Resnet_TRT_FP16/1\"\n",
    "!rm -rf $TRT_FP16_SAVED_MODEL_DIR\n",
    "\n",
    "benchmark_frozen_graph(trt_fp16_graph, TRT_FP16_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_saved_model(TRT_FP16_SAVED_MODEL_DIR)"
   ]
  },
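  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "FP16 stores fewer mantissa bits than FP32, so small numerical differences from the FP32 model are expected; this is why accuracy is re-checked after conversion. As a rough illustration only, the NumPy sketch below casts random values (an assumption standing in for real logits) to half precision and compares the predicted classes. It casts only the final values, whereas TF-TRT runs the converted layers themselves in reduced precision, so this is not a model of the actual engine.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical logits for a batch of 8 images and 1000 classes\n",
    "rng = np.random.RandomState(0)\n",
    "logits_fp32 = rng.randn(8, 1000).astype(np.float32)\n",
    "logits_fp16 = logits_fp32.astype(np.float16)\n",
    "\n",
    "# Largest rounding error introduced by the cast\n",
    "max_abs_err = np.max(np.abs(logits_fp32 - logits_fp16.astype(np.float32)))\n",
    "\n",
    "# Fraction of samples whose predicted class is unchanged\n",
    "agreement = np.mean(np.argmax(logits_fp32, axis=1) ==\n",
    "                    np.argmax(logits_fp16, axis=1))\n",
    "\n",
    "print('max abs error: %.6f' % max_abs_err)\n",
    "print('argmax agreement: %.2f' % agreement)\n",
    "```"
   ]
  },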
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"5\"></a>\n",
    "## 5. Creating TF-TRT INT8 model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Creating TF-TRT INT8 inference model requires two steps:\n",
    "\n",
    "- Step 1: creating the calibration graph, and run some data through that graph for INT-8 calibration.\n",
    "\n",
    "- Step 2: converting the calibration graph to the TF-TRT INT8 inference engine\n",
    "\n",
    "### Step 1: Creating the calibration graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calibrate model on calibration data...\n"
     ]
    }
   ],
   "source": [
    "#Now we create the TFTRT INT8 calibration graph\n",
    "trt_int8_calib_graph = trt.create_inference_graph(\n",
    "        input_graph_def=fp32_frozen_graph,\n",
    "        outputs=['classes:0'],\n",
    "        max_batch_size=BATCH_SIZE,\n",
    "        max_workspace_size_bytes=1<<32,\n",
    "        precision_mode='INT8')\n",
    "\n",
    "#Then calibrate it with 50 batches of examples\n",
    "N_runs=50\n",
    "with tf.Session(graph=tf.Graph()) as sess:\n",
    "    \n",
    "    output_node = tf.import_graph_def(\n",
    "        trt_int8_calib_graph,\n",
    "        return_elements=['classes'],\n",
    "        name='')\n",
    "    \n",
    "    # Prepare data set iterator\n",
    "    iterator = dataset.make_one_shot_iterator()\n",
    "    next_element = iterator.get_next()\n",
    "\n",
    "    print('Calibrate model on calibration data...')\n",
    "    for _ in range(N_runs):            \n",
    "            prediction = sess.run(output_node, feed_dict={\"input:0\": sess.run(next_element[0])})                        "
   ]
  },
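  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The calibration pass above lets TensorRT observe the value ranges it needs in order to map FP32 values to 8-bit integers. The snippet below is only a simplified illustration of that idea, using symmetric max-abs scaling in NumPy. It is not TensorRT's calibration algorithm, and `calibrate_scale`, `quantize`, and `dequantize` are hypothetical helper names operating on random data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def calibrate_scale(calibration_batches):\n",
    "    # The largest absolute value seen during calibration maps to 127\n",
    "    max_abs = max(float(np.max(np.abs(b))) for b in calibration_batches)\n",
    "    return max_abs / 127.0\n",
    "\n",
    "def quantize(x, scale):\n",
    "    # Round to the nearest INT8 step and saturate out-of-range values\n",
    "    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)\n",
    "\n",
    "def dequantize(q, scale):\n",
    "    return q.astype(np.float32) * scale\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "calibration_batches = [rng.randn(8, 16).astype(np.float32) for _ in range(50)]\n",
    "scale = calibrate_scale(calibration_batches)\n",
    "\n",
    "x = rng.randn(8, 16).astype(np.float32)\n",
    "x_hat = dequantize(quantize(x, scale), scale)\n",
    "in_range = np.abs(x) <= 127 * scale\n",
    "max_err = np.max(np.abs(x - x_hat)[in_range])\n",
    "print('scale: %.5f' % scale)\n",
    "print('max error for in-range values: %.5f' % max_err)\n",
    "```\n",
    "\n",
    "In this sketch, values outside the calibrated range saturate at the largest representable step, which illustrates why calibration data should resemble the data expected in deployment."
   ]
  },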
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Converting the calibration graph to inference graph\n",
    "\n",
    "Now we are ready to convert the INT8 calibration graph to the final TF-TRT INT8 inference engine, and benchmark its performance. We will also be saving this engine to a *saved model*, ready to be served elsewhere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Int8 inference model from the calibration graph\n",
    "trt_int8_calibrated_graph=trt.calib_graph_to_infer_graph(trt_int8_calib_graph)\n",
    "output_node = tf.import_graph_def(\n",
    "        trt_int8_calibrated_graph,\n",
    "        return_elements=['classes'],\n",
    "        name='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INT8_SAVED_MODEL_DIR = SAVED_MODEL_DIR + '/model/Resnet_TRT_INT8/1'\n",
    "!rm -rf $INT8_SAVED_MODEL_DIR\n",
    "\n",
    "benchmark_frozen_graph(trt_int8_calibrated_graph, INT8_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
      "\n",
      "signature_def['serving_default']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "    inputs['input'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 224, 224, 3)\n",
      "        name: input:0\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['classes'] tensor_info:\n",
      "        dtype: DT_INT64\n",
      "        shape: unknown_rank\n",
      "        name: classes:0\n",
      "  Method name is: tensorflow/serving/predict\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --all --dir $INT8_SAVED_MODEL_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we reload and verify the performance of the INT8 saved model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_saved_model(INT8_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"6\"></a>\n",
    "## 6. Calibrating TF-TRT INT8 model with raw JPEG images\n",
    "\n",
    "As an alternative to taking data in TFRecords format, in this section, we demonstrate the process of calibrating TFTRT INT-8 model from a directory of raw JPEG images. We asume that raw images have been mounted to the directory `/data/Calibration_data`.\n",
    "\n",
    "As a rule of thumb, calibration data should be a small but representative set of images that is similar to what is expected in deployment. Empirically, for common network architectures trained on imagenet data, calibration data of size 500-1000 provide good accuracy. As such, a good strategy for a dataset such as imagenet is to choose one sample from each class. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 1000 calibration files. \n",
      "/data/Calibration_data/n03710193_29586.JPEG\n",
      "/data/Calibration_data/n02085782_2655.JPEG\n",
      "...\n"
     ]
    }
   ],
   "source": [
    "data_directory = \"/data/Calibration_data\"\n",
    "calibration_files = [os.path.join(path, name) for path, _, files in os.walk(data_directory) for name in files]\n",
    "print('There are %d calibration files. \\n%s\\n%s\\n...'%(len(calibration_files), calibration_files[0], calibration_files[-1]))"
   ]
  },
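  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The file names listed above follow the pattern `<synset id>_<image number>.JPEG`, so the class of an image can be read from the prefix before the underscore. Assuming that naming scheme holds for your calibration directory, one way to keep a single image per class could look like the sketch below. The function `one_file_per_class` and the second example path are hypothetical.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def one_file_per_class(filepaths):\n",
    "    # Keep the first file (in sorted order) for each synset ID prefix\n",
    "    chosen = {}\n",
    "    for path in sorted(filepaths):\n",
    "        synset = os.path.basename(path).split('_')[0]\n",
    "        chosen.setdefault(synset, path)\n",
    "    return sorted(chosen.values())\n",
    "\n",
    "example_files = [\n",
    "    '/data/Calibration_data/n03710193_29586.JPEG',\n",
    "    '/data/Calibration_data/n03710193_101.JPEG',  # hypothetical extra file\n",
    "    '/data/Calibration_data/n02085782_2655.JPEG',\n",
    "]\n",
    "selected = one_file_per_class(example_files)\n",
    "print(selected)\n",
    "```"
   ]
  },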
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_file(filepath):\n",
    "    image = tf.read_file(filepath)\n",
    "    image = tf.image.decode_jpeg(image, channels=3)\n",
    "    image = vgg_preprocessing.preprocess_image(image, 224, 224, is_training=False)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices(calibration_files)\n",
    "dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=parse_file, batch_size=BATCH_SIZE, num_parallel_calls=20))\n",
    "dataset = dataset.repeat(count=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we proceed with the two-stage process of creating and calibrating TFTRT INT8 model.\n",
    "\n",
    "### Step 1: Creating the calibration graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calibrate model on calibration data...\n"
     ]
    }
   ],
   "source": [
    "#Now we create the TFTRT INT8 calibration graph\n",
    "trt_int8_calib_graph = trt.create_inference_graph(\n",
    "        input_graph_def=fp32_frozen_graph,\n",
    "        outputs=['classes:0'],\n",
    "        max_batch_size=BATCH_SIZE,\n",
    "        max_workspace_size_bytes=1<<32,\n",
    "        precision_mode='INT8')\n",
    "\n",
    "#Then calibrate it with 50 batches of examples\n",
    "N_runs=50\n",
    "with tf.Session(graph=tf.Graph()) as sess:\n",
    "    \n",
    "    output_node = tf.import_graph_def(\n",
    "        trt_int8_calib_graph,\n",
    "        return_elements=['classes'],\n",
    "        name='')\n",
    "    \n",
    "    # Prepare data set iterator\n",
    "    iterator = dataset.make_one_shot_iterator()\n",
    "    next_element = iterator.get_next()\n",
    "\n",
    "    print('Calibrate model on calibration data...')\n",
    "    for _ in range(N_runs):            \n",
    "            prediction = sess.run(output_node, feed_dict={\"input:0\": sess.run(next_element)})                        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Converting the calibration graph to inference graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Int8 inference model from the calibration graph\n",
    "trt_int8_calibrated_graph=trt.calib_graph_to_infer_graph(trt_int8_calib_graph)\n",
    "output_node = tf.import_graph_def(\n",
    "        trt_int8_calibrated_graph,\n",
    "        return_elements=['classes'],\n",
    "        name='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INT8_SAVED_MODEL_DIR = SAVED_MODEL_DIR + '/model/Resnet_TRT_INT8_JPEG/1'\n",
    "!rm -rf $INT8_SAVED_MODEL_DIR\n",
    "\n",
    "benchmark_frozen_graph(trt_int8_calibrated_graph, INT8_SAVED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "In this notebook, we have demonstrated the process of creating TF-TRT inference model from an original TF FP32 checkpoint. In every case, we have also verified the accuracy and speed to the resulting model. \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
